#include "ttc.h"

using namespace std;


// Constructor: reads topic names, sets up the KITTI-style camera/lidar
// calibration matrices, loads the YOLO class labels, wires up the
// time-synchronized camera+lidar subscribers, and then BLOCKS in spin().
// NOTE(review): because spin() never returns during normal operation, the
// constructor-local `sync` object stays alive; if spin() ever returns early,
// `sync` is destroyed and the fusion callback silently stops firing.
LidarCamFusion::LidarCamFusion(ros::NodeHandle &nh)
{   
    // Pre-size the marker-center buffer to 20 default-constructed points.
    marker_center = vector<cv::Point2d>(20);

    nh.getParam("cam_topic", cam_topic_);
    nh.getParam("lidar_topic", lidar_topic_);
    nh.getParam("marker_topic", marker_topic_);

    ROS_INFO("cam_topic : %s",cam_topic_.c_str());
    ROS_INFO("lidar_topic : %s",lidar_topic_.c_str());
    ROS_INFO("marker_topic : %s",marker_topic_.c_str());

    // P_rect_00: 3x4 rectified camera projection matrix.
    // R_rect_00: 4x4 rectifying rotation (homogeneous).
    // RT:        4x4 lidar-to-camera extrinsic transform (homogeneous).
    // The numeric values below look like KITTI calibration constants —
    // TODO(review): confirm they match the actual sensor rig in use.
    P_rect_00 = cv::Mat(3,4,cv::DataType<double>::type);
    R_rect_00 = cv::Mat(4,4,cv::DataType<double>::type);
    RT = cv::Mat(4,4,cv::DataType<double>::type);

    RT.at<double>(0,0) = 7.533745e-03; RT.at<double>(0,1) = -9.999714e-01; RT.at<double>(0,2) = -6.166020e-04; RT.at<double>(0,3) = -4.069766e-03;
    RT.at<double>(1,0) = 1.480249e-02; RT.at<double>(1,1) = 7.280733e-04; RT.at<double>(1,2) = -9.998902e-01; RT.at<double>(1,3) = -7.631618e-02;
    RT.at<double>(2,0) = 9.998621e-01; RT.at<double>(2,1) = 7.523790e-03; RT.at<double>(2,2) = 1.480755e-02; RT.at<double>(2,3) = -2.717806e-01;
    RT.at<double>(3,0) = 0.0; RT.at<double>(3,1) = 0.0; RT.at<double>(3,2) = 0.0; RT.at<double>(3,3) = 1.0;
    
    R_rect_00.at<double>(0,0) = 9.999239e-01; R_rect_00.at<double>(0,1) = 9.837760e-03; R_rect_00.at<double>(0,2) = -7.445048e-03; R_rect_00.at<double>(0,3) = 0.0;
    R_rect_00.at<double>(1,0) = -9.869795e-03; R_rect_00.at<double>(1,1) = 9.999421e-01; R_rect_00.at<double>(1,2) = -4.278459e-03; R_rect_00.at<double>(1,3) = 0.0;
    R_rect_00.at<double>(2,0) = 7.402527e-03; R_rect_00.at<double>(2,1) = 4.351614e-03; R_rect_00.at<double>(2,2) = 9.999631e-01; R_rect_00.at<double>(2,3) = 0.0;
    R_rect_00.at<double>(3,0) = 0; R_rect_00.at<double>(3,1) = 0; R_rect_00.at<double>(3,2) = 0; R_rect_00.at<double>(3,3) = 1;

    P_rect_00.at<double>(0,0) = 7.215377e+02; P_rect_00.at<double>(0,1) = 0.000000e+00; P_rect_00.at<double>(0,2) = 6.095593e+02; P_rect_00.at<double>(0,3) = 0.000000e+00;
    P_rect_00.at<double>(1,0) = 0.000000e+00; P_rect_00.at<double>(1,1) = 7.215377e+02; P_rect_00.at<double>(1,2) = 1.728540e+02; P_rect_00.at<double>(1,3) = 0.000000e+00;
    P_rect_00.at<double>(2,0) = 0.000000e+00; P_rect_00.at<double>(2,1) = 0.000000e+00; P_rect_00.at<double>(2,2) = 1.000000e+00; P_rect_00.at<double>(2,3) = 0.000000e+00; 

    // Combined lidar->image projection. NOTE(review): the 3x1 allocation on
    // the first line is dead — the product assignment on the next line
    // replaces the Mat header entirely (PRT ends up 3x4).
    PRT = cv::Mat(3,1,cv::DataType<double>::type);
    PRT = P_rect_00 * R_rect_00 * RT;

    // Load one class label per line into classNamesVec; silently proceeds
    // with an empty list if the file cannot be opened.
    ifstream classNamesFile(yoloClassesFile);
    if(classNamesFile.is_open())
    {
        string className = "";
        while(getline(classNamesFile, className))
            classNamesVec.push_back(className);
    }

    // NOTE(review): the darknet network is never loaded — `net` stays
    // unassigned. Re-enable these lines (or remove uses of `net`) before
    // calling yolov3Detec() or relying on the destructor's free_network().
    // net = load_network((char*)yoloModelConfiguration.c_str(), (char*)yoloModelWeights.c_str(), 0 );
    // set_batch_network(net, 1);

    // image_transport::ImageTransport it(nh);
    // pub_imgdetec = it.advertise("camDetec",10);


    cam_sub_.subscribe(nh, cam_topic_, 10);
    lidar_sub_.subscribe(nh, lidar_topic_, 10);
    marker_sub_.subscribe(nh,marker_topic_,10);



    // (1) Fuse lidar and image: approximate-time pairing of the two streams.
    typedef message_filters::sync_policies::ApproximateTime<sensor_msgs::Image, sensor_msgs::PointCloud2> cam_lidar_fuse_policy;
    typedef message_filters::Synchronizer<cam_lidar_fuse_policy> Sync;
    boost::shared_ptr<Sync> sync;
    sync.reset(new Sync(cam_lidar_fuse_policy(10), cam_sub_, lidar_sub_));
    sync->registerCallback(boost::bind(&LidarCamFusion::fusionCallback, this, _1, _2));

   
    
    ROS_INFO("[3]+LIDAR camera fusion node start \n");

    ros::MultiThreadedSpinner spinner(8); // spin callbacks on 8 threads; blocks until shutdown
    spinner.spin();
}
// Destructor: releases the darknet network if one was loaded.
// NOTE(review): the only assignment to `net` (load_network in the
// constructor) is currently commented out, so `net` may never be set —
// confirm `net` is null-initialized in ttc.h, otherwise this guard still
// reads an indeterminate pointer.
LidarCamFusion::~LidarCamFusion(){
    if (net != nullptr)   // avoid freeing a network that was never loaded
    {
        free_network(net);
    }
}



/**
 * Runs darknet/YOLOv3 on the given image, keeps detections whose best class
 * is id 2 (car, assuming the COCO label order — TODO confirm), stores their
 * rectangles in `boxes` / class ids in `classNames` / centers in
 * `box_center`, and draws the rectangles onto visImg in place.
 *
 * @param visImg BGR image; annotated in place with detection rectangles.
 *
 * Fixes vs. the previous version:
 *  - the member `thresh` is no longer mutated by the per-detection loop
 *    (it used to be raised to the best probability seen, silently changing
 *    the threshold passed to get_network_boxes() on every later frame);
 *  - `box_center` is cleared each call (it previously grew without bound);
 *  - the image buffers use std::vector instead of raw malloc/free, so they
 *    cannot leak;
 *  - `className` is initialized before use.
 */
void LidarCamFusion::yolov3Detec(cv::Mat& visImg)
{
    // Darknet expects planar RGB floats; convert from OpenCV's BGR first.
    cv::Mat rgbImg;
    cvtColor(visImg, rgbImg, COLOR_BGR2RGB);

    std::vector<float> srcImg(static_cast<size_t>(rgbImg.rows) * rgbImg.cols * 3);
    imgConvert(rgbImg, srcImg.data());

    // Resize to the network's input resolution.
    std::vector<float> resizeImg(static_cast<size_t>(net->w) * net->h * 3);
    imgResize(srcImg.data(), resizeImg.data(), visImg.cols, visImg.rows, net->w, net->h);

    network_predict(net, resizeImg.data());

    int nboxes = 0;
    float nms = 0.35;
    detection *dets = get_network_boxes(net, rgbImg.cols, rgbImg.rows, thresh, 0.5, 0, 1, &nboxes);

    if (nms)
    {
        do_nms_sort(dets, nboxes, classes, nms);
    }

    boxes.clear();
    classNames.clear();
    box_center.clear();   // previously never cleared — grew every frame

    for (int i = 0; i < nboxes; i++)
    {
        // Pick the most probable class above 0.5 for this detection, using a
        // LOCAL threshold so the member `thresh` is left untouched.
        float bestProb = 0.5f;
        bool found = false;
        int className = -1;
        for (int j = 0; j < classes; j++)
        {
            if (dets[i].prob[j] > bestProb)
            {
                bestProb = dets[i].prob[j];
                found = true;
                className = j;
            }
        }

        if (found && className == 2)
        {
            // bbox is center/size in relative coordinates; convert to pixel
            // corners and clamp to the image.
            int left  = (dets[i].bbox.x - dets[i].bbox.w / 2.) * visImg.cols;
            int right = (dets[i].bbox.x + dets[i].bbox.w / 2.) * visImg.cols;
            int top   = (dets[i].bbox.y - dets[i].bbox.h / 2.) * visImg.rows;
            int bot   = (dets[i].bbox.y + dets[i].bbox.h / 2.) * visImg.rows;

            if (left < 0)                 left = 0;
            if (right > visImg.cols - 1)  right = visImg.cols - 1;
            if (top < 0)                  top = 0;
            if (bot > visImg.rows - 1)    bot = visImg.rows - 1;

            Rect box(left, top, std::abs(right - left), std::abs(bot - top));
            boxes.push_back(box);
            classNames.push_back(className);
        }
    }

    boxes_updated = true;

    free_detections(dets, nboxes);

    for (size_t i = 0; i < boxes.size(); i++)
    {
        // Pseudo-random but per-class-stable colour, as in darknet's demo.
        int offset = classNames[i] * 123457 % 80;
        float red   = 255 * get_color(2, offset, 80);
        float green = 255 * get_color(1, offset, 80);
        float blue  = 255 * get_color(0, offset, 80);

        rectangle(visImg, boxes[i], Scalar(blue, green, red), 2);

        String label = String(classNamesVec[classNames[i]]);
        int baseLine = 0;
        Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
        (void)labelSize;   // only needed if the putText overlay is re-enabled
        // putText(visImg,label, Point(boxes[i].x,boxes[i].y+labelSize.height),FONT_HERSHEY_SIMPLEX, 1,Scalar(red,blue,green),2);

        // Record the box center for the downstream marker association.
        box_center.push_back(cv::Point2d(boxes[i].x + boxes[i].width / 2.0,
                                         boxes[i].y + boxes[i].height / 2.0));
    }
}
/**
 * Time-synchronized camera+lidar callback.
 *
 * Pipeline: convert the image to BGR, range-filter the point cloud
 * (0 < x < 100 m), push a DataFrame into the ring buffer, run YOLO object
 * detection, detect/describe Shi-Tomasi keypoints, and — once enough frames
 * are buffered — match descriptors and bounding boxes against the previous
 * frame and compute lidar- and camera-based TTC per matched box pair,
 * visualizing the result.
 *
 * @param image camera frame
 * @param cloud lidar scan
 *
 * Fixes vs. the previous version:
 *  - prevBB/currBB were uninitialized and dereferenced even when no bounding
 *    box matched (undefined behaviour); unmatched pairs are now skipped;
 *  - keypoints.erase(begin()+50, end()) was undefined behaviour when fewer
 *    than 50 keypoints were found, and discarded an arbitrary tail; a single
 *    retainBest() keeps the 50 strongest safely;
 *  - dataBuffer grew without bound (pop_front was commented out); it is now
 *    capped, keeping the frames the code actually reads;
 *  - sprintf replaced with snprintf.
 */
void LidarCamFusion::fusionCallback(const sensor_msgs::ImageConstPtr& image, const sensor_msgs::PointCloud2ConstPtr& cloud)
{
    ROS_INFO("~~~~~~");

    /* 1---image process */
    cv_bridge::CvImagePtr cv_ptr = cv_bridge::toCvCopy(image, sensor_msgs::image_encodings::BGR8);
    cv::Mat visImg = cv_ptr->image.clone();

    /* 2---lidar process: keep only returns in front of the sensor */
    pcl::PointCloud<pcl::PointXYZ>::Ptr pcl_cloud(new pcl::PointCloud<pcl::PointXYZ>);
    pcl::fromROSMsg(*cloud, *pcl_cloud);
    pcl::PassThrough<pcl::PointXYZ> pass;
    pass.setInputCloud(pcl_cloud);
    pass.setFilterFieldName("x");
    pass.setFilterLimits(0.0, 100);
    pass.filter(*pcl_cloud);

    std::vector<LidarPoint> lidarpoints;
    lidarpoints.reserve(pcl_cloud->points.size());
    for (size_t i = 0; i < pcl_cloud->points.size(); i++)
    {
        LidarPoint p;
        p.x = pcl_cloud->points[i].x;
        p.y = pcl_cloud->points[i].y;
        p.z = pcl_cloud->points[i].z;
        lidarpoints.push_back(p);
    }

    DataFrame frame;
    frame.cameraImg = visImg;
    frame.lidarPoints = lidarpoints;
    dataBuffer.push_back(frame);
    // Cap the buffer: only the last two frames are ever read below
    // (end()-1 / end()-2), so keeping three preserves the size()>2 gate
    // while preventing unbounded memory growth.
    while (dataBuffer.size() > 3)
    {
        dataBuffer.pop_front();
    }
    string yoloBasePath = "./";
    cout << "dataBuffer.size() = " << dataBuffer.size() << endl;

    // YOLO object detection on the newest frame.
    detectObjects((dataBuffer.end() - 1)->cameraImg, (dataBuffer.end() - 1)->boundingBoxes, 0.2, 0.4,
                  yoloBasePath, yoloClassesFile, yoloModelConfiguration, yoloModelWeights, true);

    cv::Mat imgGray;
    cv::cvtColor((dataBuffer.end() - 1)->cameraImg, imgGray, cv::COLOR_BGR2GRAY);

    // Shi-Tomasi keypoints, limited to the strongest responses.
    vector<cv::KeyPoint> keypoints;
    detKeypointsShiTomasi(keypoints, imgGray, false);

    const int maxKeypoints = 50;
    cv::KeyPointsFilter::retainBest(keypoints, maxKeypoints);

    (dataBuffer.end() - 1)->keypoints = keypoints;

    // BRISK descriptors for the retained keypoints.
    cv::Mat descriptors;
    string descriptorType = "BRISK";
    descKeypoints((dataBuffer.end() - 1)->keypoints, (dataBuffer.end() - 1)->cameraImg, descriptors, descriptorType);
    (dataBuffer.end() - 1)->descriptors = descriptors;

    if (dataBuffer.size() > 2)
    {
        vector<cv::DMatch> matches;
        string matcherType = "MAT_BF";              // MAT_BF, MAT_FLANN
        string matchDescriptorType = "DES_BINARY";  // DES_BINARY, DES_HOG
        string selectorType = "SEL_NN";             // SEL_NN, SEL_KNN

        matchDescriptors((dataBuffer.end() - 2)->keypoints, (dataBuffer.end() - 1)->keypoints,
                         (dataBuffer.end() - 2)->descriptors, (dataBuffer.end() - 1)->descriptors,
                         matches, matchDescriptorType, matcherType, selectorType);

        (dataBuffer.end() - 1)->kptMatches = matches;

        cout << "#7 : MATCH KEYPOINT DESCRIPTORS done" << endl;

        // Associate bounding boxes between previous and current frame using
        // the keypoint matches.
        map<int, int> bbBestMatches;
        matchBoundingBoxes(matches, bbBestMatches, *(dataBuffer.end() - 2), *(dataBuffer.end() - 1));
        (dataBuffer.end() - 1)->bbMatches = bbBestMatches;

        // loop over all BB match pairs
        for (auto it1 = (dataBuffer.end() - 1)->bbMatches.begin(); it1 != (dataBuffer.end() - 1)->bbMatches.end(); ++it1)
        {
            // Find the bounding boxes associated with the current match pair;
            // skip the pair entirely if either side has no matching box.
            BoundingBox *prevBB = nullptr;
            BoundingBox *currBB = nullptr;
            for (auto it2 = (dataBuffer.end() - 1)->boundingBoxes.begin(); it2 != (dataBuffer.end() - 1)->boundingBoxes.end(); ++it2)
            {
                if (it1->second == it2->boxID)
                {
                    currBB = &(*it2);
                }
            }
            for (auto it2 = (dataBuffer.end() - 2)->boundingBoxes.begin(); it2 != (dataBuffer.end() - 2)->boundingBoxes.end(); ++it2)
            {
                if (it1->first == it2->boxID)
                {
                    prevBB = &(*it2);
                }
            }
            if (currBB == nullptr || prevBB == nullptr)
            {
                continue;
            }

            // Only compute TTC if both boxes contain lidar points.
            if (currBB->lidarPoints.size() > 0 && prevBB->lidarPoints.size() > 0)
            {
                double ttcLidar;
                computeTTCLidar(prevBB->lidarPoints, currBB->lidarPoints, 10, ttcLidar);

                double ttcCamera;
                clusterKptMatchesWithROI(*currBB, (dataBuffer.end() - 2)->keypoints, (dataBuffer.end() - 1)->keypoints, (dataBuffer.end() - 1)->kptMatches);
                computeTTCCamera((dataBuffer.end() - 2)->keypoints, (dataBuffer.end() - 1)->keypoints, currBB->kptMatches, 10, ttcCamera);

                // Visualize: project the box's lidar points into the image
                // and overlay the ROI and both TTC estimates.
                cv::Mat visImg2 = (dataBuffer.end() - 1)->cameraImg.clone();
                // showLidarImgOverlay(visImg, currBB->lidarPoints, P_rect_00, R_rect_00, RT, &visImg2);

                vector<cv::Point3f> msg_;
                msg_.reserve(currBB->lidarPoints.size());
                for (size_t i = 0; i < currBB->lidarPoints.size(); ++i)
                {
                    msg_.push_back(cv::Point3f(currBB->lidarPoints[i].x, currBB->lidarPoints[i].y, currBB->lidarPoints[i].z));
                }

                cv::Size imageSize;
                imageSize.width = 640;
                imageSize.height = 480;

                vector<cv::Point2d> lidarPoints2d = pointcloud2_to_image(msg_, imageSize);
                for (size_t i = 0; i < lidarPoints2d.size(); ++i)
                {
                    cv::circle(visImg2, lidarPoints2d[i], 10, cv::Scalar(255, 255, 0), -1);
                }
                cv::rectangle(visImg2,
                              cv::Point(currBB->roi.x, currBB->roi.y),
                              cv::Point(currBB->roi.x + currBB->roi.width, currBB->roi.y + currBB->roi.height),
                              cv::Scalar(0, 255, 0), 2);

                char str[200];
                snprintf(str, sizeof(str), "TTC Lidar : %f s, TTC Camera : %f s", ttcLidar, ttcCamera);
                putText(visImg2, str, cv::Point2f(80, 50), cv::FONT_HERSHEY_PLAIN, 2, cv::Scalar(0,0,255));
                string windowName11 = "融合检测结果";
                cv::namedWindow(windowName11, 4);
                cv::imshow(windowName11, visImg2);
                cout << "Press key to continue to next frame" << endl;
                cv::waitKey(0);
            }
        }
    }
}

/* Projection of the real-vehicle 16-beam Velodyne onto the camera image. */
/**
 * Projects 3D lidar points into 2D pixel coordinates using a hard-coded
 * camera intrinsic matrix, plumb-bob distortion coefficients, and a
 * lidar-to-camera extrinsic matrix.
 *
 * @param msg       3D points in the lidar frame
 * @param imageSize target image size (width/height read but not used to
 *                  clip the returned points)
 * @return one 2D pixel point per input point, rounded to integer pixels
 *
 * NOTE(review): the calibration constants below are for one specific rig
 * (alternatives left in trailing comments) — confirm before reuse.
 */
vector<Point2d> LidarCamFusion::pointcloud2_to_image(const vector<Point3f> msg,
                                         const cv::Size& imageSize)
{
    // w/h are currently unused: points outside the image are NOT clipped.
    int w = imageSize.width;
    int h = imageSize.height;

    vector<Point2d> img_points;
    // All three calibration matrices are single-precision (CV_32F); the
    // per-point accumulation below is done in double.
    cv::Mat intrinsic_matrix = cv::Mat::eye(3, 3, CV_32FC1);
    cv::Mat distCoeff = cv::Mat(5, 1, DataType<float>::type);
    cv::Mat cameraExtrinsicMat = cv::Mat(4, 4, DataType<float>::type);

    // Pinhole intrinsics: fx, fy on the diagonal; cx, cy in the last column.
    intrinsic_matrix.at<float>(0, 0) = 7.1086713154276424e+02;//748.513026;//1.6415318549788924e+003;
    intrinsic_matrix.at<float>(1, 0) = 0;
    intrinsic_matrix.at<float>(2, 0) = 0;
    intrinsic_matrix.at<float>(0, 1) = 0;
    intrinsic_matrix.at<float>(1, 1) = 7.1002928273752718e+02;//746.151191;//1.7067753507885654e+003;
    intrinsic_matrix.at<float>(2, 1) = 0;
    intrinsic_matrix.at<float>(0, 2) = 3.4648360876958566e+02;//344.761595;//5.3262822453148601e+002;
    intrinsic_matrix.at<float>(1, 2) = 2.1356982249423606e+02;//239.344274;//3.8095355839052968e+002;
    intrinsic_matrix.at<float>(2, 2) = 1.0;

    // Distortion coefficients, applied manually below in OpenCV's plumb-bob
    // order: k1, k2, p1, p2, k3.
    distCoeff.at<float>(0) = -1.2009796794520683e-01;// -0.181816;//-7.9134632415085826e-001;
    distCoeff.at<float>(1) = -1.7909877827451678e-01;//0.033556;//1.5623584435644169e+000;
    distCoeff.at<float>(2) = 1.4330751500649312e-03;//-0.007749;//-3.3916502741726508e-002;
    distCoeff.at<float>(3) = -1.6035897606267967e-03;//-0.000288;//-1.3921577146136694e-002;
    distCoeff.at<float>(4) = 2.4196197710843320e-01;//0.0;//1.1430734623697941e-002;

    // 4x4 homogeneous extrinsic transform (rotation in the top-left 3x3,
    // translation in the last column).
    cameraExtrinsicMat.at<float>(0, 0) = -5.3421013207097412e-02;//-2.2885227522464435e-01;
    cameraExtrinsicMat.at<float>(0, 1) = -2.9226997203065541e-02;//-1.5214255212709749e-01;
    cameraExtrinsicMat.at<float>(0, 2) = 9.9814426711894688e-01;//9.6149845551449376e-01;
    cameraExtrinsicMat.at<float>(0, 3) = 1.2616455554962158e-01;//8.4437644109129906e-03;
    cameraExtrinsicMat.at<float>(1, 0) = -9.9851694441775540e-01;//-9.7303107309702697e-01;
    cameraExtrinsicMat.at<float>(1, 1) = 1.2067176116371425e-02;//6.3942072996323596e-03;
    cameraExtrinsicMat.at<float>(1, 2) = -5.3087615987213732e-02;//-2.3058543948102478e-01;
    cameraExtrinsicMat.at<float>(1, 3) = -1.5259952284395695e-02;//-2.1137388423085213e-02;
    cameraExtrinsicMat.at<float>(2, 0) = -1.0493191056894891e-02;//2.8933836803155533e-02;
    cameraExtrinsicMat.at<float>(2, 1) = -9.9949995792649671e-01;//-9.8833787640930915e-01;
    cameraExtrinsicMat.at<float>(2, 2) = -2.9828292716396088e-02;//-1.4950275964872894e-01;
    cameraExtrinsicMat.at<float>(2, 3) = -5.2367319585755467e-04;//2.1896010637283325e-01;
    cameraExtrinsicMat.at<float>(3, 0) = 0.0;
    cameraExtrinsicMat.at<float>(3, 1) = 0.0;
    cameraExtrinsicMat.at<float>(3, 2) = 0.0;
    cameraExtrinsicMat.at<float>(3, 3) = 1.0;
    
    // Invert the extrinsic transform: for R|t, the inverse is R^T | -R^T t.
    // NOTE(review): invTt's initial CV_64F allocation is discarded — the
    // assignment below replaces the Mat header, so invTt is CV_32F like
    // invT, which is why it is read with at<float>() further down.
    cv::Mat invRt;
    cv::Mat invTt = cv::Mat(3, 1, cv::DataType<double>::type);
    invRt = cameraExtrinsicMat(cv::Rect(0, 0, 3, 3));
    cv::Mat invT = -invRt.t() * (cameraExtrinsicMat(cv::Rect(3, 0, 1, 3)));
    invTt = invT.t();

    cv::Mat point(1, 3, CV_64F);
    cv::Point2d imagepoint;

    for(int size = 0; size < msg.size(); ++size)
    {
        // Transform the lidar point into the camera frame:
        // point = R^T * p + (-R^T t), accumulated in double precision.
        for (int i = 0; i < 3; i++)
        {
            point.at<double>(i) = invTt.at<float>(i);
            // std::cout<<"point1 = "<<point<<std::endl;
            for (int j = 0; j < 3; j++)
            {
                if(j == 0) {point.at<double>(i) += double(msg[size].x) * invRt.at<float>(j, i);}
                else if(j == 1) {point.at<double>(i) += double(msg[size].y) * invRt.at<float>(j, i);}
                else {point.at<double>(i) += double(msg[size].z) * invRt.at<float>(j, i);}
            }
        }

        // Perspective-divide to normalized image coordinates.
        // NOTE(review): points at or behind the camera (z <= 0) are not
        // filtered and will project to spurious pixels.
        double tmpx = point.at<double>(0) / point.at<double>(2);
        double tmpy = point.at<double>(1) / point.at<double>(2);
        // Radial distortion factor: 1 + k1*r^2 + k2*r^4 + k3*r^6.
        double r2 = tmpx * tmpx + tmpy * tmpy;
        double tmpdist =
            1 + distCoeff.at<float>(0) * r2 + distCoeff.at<float>(1) * r2 * r2 + distCoeff.at<float>(4) * r2 * r2 * r2;

        // Apply radial + tangential (p1, p2) distortion, then the intrinsics
        // to get pixel coordinates.
        imagepoint.x =
            tmpx * tmpdist + 2 * distCoeff.at<float>(2) * tmpx * tmpy + distCoeff.at<float>(3) * (r2 + 2 * tmpx * tmpx);
        imagepoint.y =
            tmpy * tmpdist + distCoeff.at<float>(2) * (r2 + 2 * tmpy * tmpy) + 2 * distCoeff.at<float>(3) * tmpx * tmpy;
        imagepoint.x = intrinsic_matrix.at<float>(0, 0) * imagepoint.x + intrinsic_matrix.at<float>(0, 2);
        imagepoint.y = intrinsic_matrix.at<float>(1, 1) * imagepoint.y + intrinsic_matrix.at<float>(1, 2);

        // Round to the nearest pixel.
        int px = int(imagepoint.x + 0.5);
        int py = int(imagepoint.y + 0.5);
        Point2d tmp_pt(px, py);
        img_points.push_back(tmp_pt);
    }
    
    return img_points;
}



// Returns channel c of a palette colour for x in [0, max): maps x onto the
// [0, 5] range of the `colors` table and linearly interpolates between the
// two neighbouring palette entries (same scheme as darknet's demo drawing).
float LidarCamFusion::get_color(int c, int x, int max)
{
    float pos = ((float)x / max) * 5;
    int lo = floor(pos);
    int hi = ceil(pos);
    float frac = pos - lo;
    return (1 - frac) * colors[lo][c] + frac * colors[hi][c];
}

// Entry point: initialize ROS and construct the fusion node, whose
// constructor blocks in its own spinner until shutdown.
int main(int argc, char **argv)
{
    ros::init(argc, argv, "LidarCamFusion");
    ros::NodeHandle private_nh("~");
    LidarCamFusion fusion_node(private_nh);
    return 0;
}